#!/usr/bin/env python3

#==============================================================================
# Functionality
#==============================================================================
import pdb
import sys
import os
import re

# utility funcs, classes, etc go here.

def asserting(cond):
    """Assert `cond`, dropping into the debugger first when it is falsy.

    A truthy `cond` returns immediately. A falsy `cond` opens a pdb
    session via `pdb.set_trace()`; once the session continues, the
    assertion on `cond` is evaluated.
    """
    if cond:
        return
    # Give the developer a chance to inspect state before the failure.
    pdb.set_trace()
    assert cond

def has_stdin():
    """Return True when standard input is not attached to a terminal."""
    attached_to_terminal = sys.stdin.isatty()
    return not attached_to_terminal

def reg(pat, flags=0):
    """Compile `pat` with `re.VERBOSE` always enabled.

    Any extra `flags` are OR-ed together with `re.VERBOSE`. Returns the
    compiled pattern object.
    """
    combined = re.VERBOSE | flags
    return re.compile(pat, combined)

#==============================================================================
# Cmdline
#==============================================================================
import argparse

# Command-line interface. Arguments that are not recognised options are
# collected by main() via parse_known_args() and treated as paths.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""
For every directory given on the command line (or, one per line, on stdin),
print the checkpoint files that can be removed.

The newest checkpoint in each window of --steps steps is kept, and files
belonging to the five most recent checkpoints are never listed.
""")

parser.add_argument('-v', '--verbose',
    action="store_true",
    help="verbose output")

# type=int makes argparse reject a non-numeric value with a usage error and
# hands the rest of the script an int instead of a string.
parser.add_argument('-s', '--steps',
    type=int,
    default=5000,
    help="Maximum number of steps between model checkpoints to keep")

# Parsed arguments. main() fills this in when it is still unset, so code that
# imports this module can assign it beforehand to skip command-line parsing.
args = None

#==============================================================================
# Main
#==============================================================================
import tensorflow as tf
import re
import os

def bucket_checkpoints(path):
  """Group the checkpoint files listed under `path` by step number.

  An entry is treated as a checkpoint file when its name matches
  `.ckpt-<digits>.`; the digits are parsed as the step. Other entries are
  ignored.

  Returns a dict mapping each step (int) to a list of
  `os.path.join(path, entry)` values, in the order the entries were listed.
  """
  by_step = {}
  for name in tf.io.gfile.listdir(path):
    found = re.match('.*[.]ckpt-([0-9]+)[.]', name)
    if found is None:
      continue
    by_step.setdefault(int(found.group(1)), []).append(os.path.join(path, name))
  return by_step

def checkpoints_to_keep(path, keep_every):
  """Return the checkpoint files under `path` that must be preserved.

  Steps are grouped into consecutive windows of `keep_every` steps:
  [0, keep_every), [keep_every, 2 * keep_every), and so on. For every window
  that contains at least one checkpoint, the files of its highest step are
  kept.

  Args:
    path: directory handed to `bucket_checkpoints`.
    keep_every: window width in steps; must be positive.

  Returns:
    A list of file paths ordered by ascending step. The list is empty when
    no checkpoints are found.

  Raises:
    ValueError: if `keep_every` is not positive.
  """
  if keep_every <= 0:
    raise ValueError('keep_every must be positive, got %r' % (keep_every,))
  checkpoints = bucket_checkpoints(path)
  # One pass over the steps: remember the highest step seen in each window.
  newest_in_window = {}
  for step in checkpoints:
    window = step // keep_every
    if window not in newest_in_window or step > newest_in_window[window]:
      newest_in_window[window] = step
  results = []
  seen = set()  # constant-time duplicate check while preserving order
  for window in sorted(newest_in_window):
    for entry in checkpoints[newest_in_window[window]]:
      if entry not in seen:
        seen.add(entry)
        results.append(entry)
  return results

def checkpoints_to_remove(path, keep_every, keep_latest=5):
  """Yield the checkpoint files under `path` that may be deleted.

  A file is yielded when it is not selected by `checkpoints_to_keep` and
  does not belong to one of the `keep_latest` highest steps.

  Args:
    path: directory handed to `bucket_checkpoints`.
    keep_every: window width in steps, forwarded to `checkpoints_to_keep`.
    keep_latest: number of highest steps that are never yielded. Values of
      zero or less protect no steps.

  Yields:
    File paths in ascending step order.
  """
  checkpoints = bucket_checkpoints(path)
  # `or ()` guards against a None result for a directory without checkpoints;
  # a set makes the membership test below constant-time.
  keep = set(checkpoints_to_keep(path, keep_every) or ())
  # Sanity check: if checkpoints exist, at least one file must be kept.
  assert len(checkpoints) <= 0 or len(keep) > 0
  ordered = sorted(checkpoints.items())
  # Explicit cutoff instead of a negative slice index, so keep_latest == 0
  # and keep_latest > len(ordered) both behave as expected.
  cutoff = max(0, len(ordered) - keep_latest)
  for _, files in ordered[:cutoff]:
    for entry in files:
      if entry not in keep:
        yield entry

def run():
    """Print, one per line, the removable checkpoint files for each path.

    Paths come from `args.args`. When no paths were given, or when stdin is
    not a terminal, stdin is read in full and each of its lines is appended
    to `args.args`. If no paths were given and stdin is a terminal, a prompt
    is printed before reading.
    """
    if args.verbose:
        print(args)
    if not args.args:
        # Nothing on the command line: the paths must come from stdin.
        if not has_stdin():
            print('Enter input (press Ctrl-D when done):')
        read_stdin = True
    else:
        # Paths were given; still pick up extra ones when stdin is piped.
        read_stdin = has_stdin()
    if read_stdin:
        args.args.extend(sys.stdin.read().splitlines())
    keep_every = int(args.steps)
    for path in args.args:
        # Collect the full list before printing anything for this path.
        removable = list(checkpoints_to_remove(path, keep_every))
        for entry in removable:
            print(entry)

def main():
    """Parse the command line if needed, then call `run()`.

    The module-level `args` is populated only when it is still unset;
    arguments that are not recognised options are stored on it as
    `args.args`. Returns the result of `run()`, or None when an IOError was
    handled.
    """
    global args
    try:
        if not args:
            parsed, leftovers = parser.parse_known_args()
            parsed.args = leftovers
            args = parsed
        return run()
    except IOError:
        # http://stackoverflow.com/questions/15793886/how-to-avoid-a-broken-pipe-error-when-printing-a-large-amount-of-formatted-data
        # Best effort: close both standard streams, ignoring further errors.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.close()
            except IOError:
                pass

if __name__ == "__main__":
    main()

